function plot_basis_function_2d(basis_type, flag)
%PLOT_BASIS_FUNCTION_2D Surface-plot reference functions over their reference element.
%
%   plot_basis_function_2d(basis_type) evaluates every function handle in
%   the cell array returned by reference_basis_function_2d(basis_type) on a
%   100-by-100 grid and draws each one with surf in the current axes.
%
%   plot_basis_function_2d(basis_type, flag) draws only the first
%   min(flag, size(dN, 1)) handles. An empty flag means "draw all".
%
%   Inputs:
%     basis_type - "P1", "P1b" or "P2": sampled on [0,1]x[0,1], with values
%                  outside the triangle x >= 0, y >= 0, x + y <= 1 set to
%                  NaN so that surf does not draw them.
%                  "Q1", "Q1b" or "Q2": sampled on the square [-1,1]x[-1,1].
%     flag       - (optional) upper bound on the number of handles drawn.
%
%   Errors:
%     "Invalid basis type" for any other basis_type (raised after
%     reference_basis_function_2d has been called, which may itself reject
%     the type first).
%
%   Side effects:
%     Draws into the current axes and leaves "hold on" active, so all
%     selected surfaces are overlaid in the same axes.
%
%   NOTE(review): each entry dN{i} is assumed to be callable as dN{i}(x, y)
%   with matrix arguments; confirm against reference_basis_function_2d.

dN = reference_basis_function_2d(basis_type);

% Number of handles to draw; a missing or empty flag selects all of them.
if nargin < 2 || isempty(flag)
    Nb = size(dN, 1);
else
    Nb = min(flag, size(dN, 1));
end

% Pick the sampling grid and the mask of grid nodes lying outside the
% reference element.
n_samples = 100;
switch basis_type
    case {"P1", "P1b", "P2"}
        grid_1d = linspace(0, 1, n_samples);
        [x, y] = meshgrid(grid_1d, grid_1d);
        % The grid already satisfies 0 <= x, y <= 1, so only the hypotenuse
        % needs testing. The tolerance keeps nodes that lie on x + y = 1 up
        % to round-off; an exact comparison can drop them by one ulp.
        outside = (x + y) > 1 + 10 * eps;
    case {"Q1", "Q1b", "Q2"}
        grid_1d = linspace(-1, 1, n_samples);
        [x, y] = meshgrid(grid_1d, grid_1d);
        outside = false(size(x));
    otherwise
        error("Invalid basis type");
end

for i = 1:Nb
    z = dN{i}(x, y);
    % A handle for a constant function returns a scalar; expand it so that
    % masking and surf receive a matrix matching the grid.
    if isscalar(z)
        z = repmat(z, size(x));
    end
    z(outside) = NaN;
    surf(x, y, z);
    hold on;
end

end